Spectral Clustering
Overview
With the data ingestion and preprocessing foundations completed, the current notebook will demonstrate a simple machine learning workflow to identify water in our satellite images. For this particular approach, we will utilize spectral clustering to assign labels to each x,y point in our data space based on the similarity of the combined set of pixels across wavelength-bands in our image stack. Our example approach uses a version of spectral clustering from dask_ml that is a scalable equivalent of what is available in scikit-learn. To focus on the analysis, we will begin by performing this analysis on a single image and then conclude by comparing across images by combining our regridding steps from the previous notebook with spectral clustering.
Our present approach is just one example of an analysis, but any library, algorithm, or simulator could be used at this stage if it can accept our processed array data.
Prerequisites
| Concepts | Importance | Notes |
| --- | --- | --- |
| Data ingestion and preprocessing (previous notebooks) | Necessary | |
| dask_ml | Helpful | spectral clustering at scale |
| scikit-learn | Helpful | spectral clustering |
Time to learn: 20 minutes.
Imports
import intake
import numpy as np
import xarray as xr
from dask_ml.cluster import SpectralClustering
from dask.distributed import Client
import cartopy.crs as ccrs
import geoviews as gv
import hvplot.xarray
import warnings
# Ignore a warning about the format of epsg codes
warnings.simplefilter('ignore', FutureWarning)
Loading data
Let’s start by loading the small version of the Landsat data. This should be familiar from the previous notebooks.
cat = intake.open_catalog('./data/catalog.yml')
landsat_5_da = cat.landsat_5_small.to_dask()
landsat_5_da
<xarray.DataArray (band: 6, y: 300, x: 300)>
dask.array<concatenate, shape=(6, 300, 300), dtype=float64, chunksize=(1, 50, 50), chunktype=numpy.ndarray>
Coordinates:
* band (band) int64 1 2 3 4 5 7
* y (y) float64 4.309e+06 4.309e+06 4.309e+06 ... 4.264e+06 4.264e+06
* x (x) float64 3.324e+05 3.326e+05 3.327e+05 ... 3.771e+05 3.772e+05
Attributes:
transform: (150.0, 0.0, 332325.0, 0.0, -150.0, 4309275.0)
crs: +init=epsg:32611
res: (150.0, 150.0)
is_tiled: 0
nodatavals: (nan,)
scales: (1.0,)
offsets: (0.0,)
AREA_OR_POINT: Area
Reshaping Data
The shape of our data is currently n_bands, n_y, n_x. In order for dask-ml / scikit-learn to consume our data, we’ll need to reshape our image stacks into n_samples, n_features, where n_features is the number of wavelength-bands and n_samples is the total number of pixels in each wavelength-band image. Essentially, we’ll be creating a vector of pixels out of each image, where each pixel has multiple features (bands), but the ordering of the pixels is no longer relevant to the computation. We’ll first look at using NumPy, then Xarray.
NumPy
Data can be reshaped at the lowest level using NumPy, by getting the underlying values from the xarray.DataArray, and using flatten and transpose to get the right shape.
arr = landsat_5_da.values
arr.shape
(6, 300, 300)
flattened_npa = np.array([arr[i].flatten() for i in range(arr.shape[0])])
flattened_npa
array([[ 640., 842., 864., ..., 1309., 1636., 1199.],
[ 810., 1096., 1191., ..., 1736., 2250., 1736.],
[1007., 1345., 1471., ..., 2202., 2783., 1994.],
[1221., 1662., 1809., ..., 2755., 3431., 2223.],
[1819., 2596., 2495., ..., 3067., 3802., 2665.],
[1682., 2215., 2070., ..., 2860., 3724., 2333.]])
flattened_npa.shape
(6, 90000)
flattened_t_npa = flattened_npa.transpose()
flattened_t_npa.shape
(90000, 6)
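As an aside, NumPy can do this flatten-and-transpose in a single step with reshape; here is a quick sketch on a toy array (the toy values are purely illustrative):

```python
import numpy as np

# Toy (band, y, x) stack standing in for the Landsat array
arr = np.arange(2 * 3 * 3, dtype=float).reshape(2, 3, 3)

# Loop-based flatten-and-transpose, as in the cells above
flattened_t = np.array([arr[i].flatten() for i in range(arr.shape[0])]).transpose()

# Equivalent one-step version: collapse y and x, then transpose to (n_samples, n_features)
flattened_t_fast = arr.reshape(arr.shape[0], -1).T

assert flattened_t.shape == (9, 2)
assert np.array_equal(flattened_t, flattened_t_fast)
```

Both forms traverse the pixels in the same C order, so the resulting sample ordering is identical.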
Now we have the data in n_samples, n_features, but since these are bare NumPy arrays without any coordinates or labeled dimensions, it will be harder to recreate the images after the analysis.
Xarray
Let’s consider a better way to reshape the data that preserves the metadata. By using xarray methods to flatten the data, we can keep track of the coordinate labels ‘x’ and ‘y’ along the way. This means that we have the ability to reshape back to our original array at any time with no information loss!
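This lossless round trip is worth seeing in miniature. Here is a sketch on a toy DataArray (the coordinate values are made up for illustration):

```python
import numpy as np
import xarray as xr

# Toy (band, y, x) array to illustrate the lossless stack/unstack round trip
da = xr.DataArray(np.arange(18).reshape(2, 3, 3),
                  dims=('band', 'y', 'x'),
                  coords={'band': [1, 2], 'y': [0., 1., 2.], 'x': [0., 1., 2.]})

stacked = da.stack(z=('x', 'y'))  # shape (2, 9) with a MultiIndex on z
round_trip = stacked.unstack().transpose('band', 'y', 'x')

assert stacked.shape == (2, 9)
assert round_trip.equals(da)
```

Because the MultiIndex on z remembers both original coordinates, unstack can rebuild the 2D image exactly.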
flattened_xda = landsat_5_da.stack(z=('x','y'))
flattened_xda
<xarray.DataArray (band: 6, z: 90000)>
dask.array<reshape, shape=(6, 90000), dtype=float64, chunksize=(1, 3000), chunktype=numpy.ndarray>
Coordinates:
* band (band) int64 1 2 3 4 5 7
* z (z) object MultiIndex
* x (z) float64 3.324e+05 3.324e+05 3.324e+05 ... 3.772e+05 3.772e+05
* y (z) float64 4.309e+06 4.309e+06 4.309e+06 ... 4.264e+06 4.264e+06
Attributes:
transform: (150.0, 0.0, 332325.0, 0.0, -150.0, 4309275.0)
crs: +init=epsg:32611
res: (150.0, 150.0)
is_tiled: 0
nodatavals: (nan,)
scales: (1.0,)
offsets: (0.0,)
AREA_OR_POINT: Area
We can reorder the dimensions using DataArray.transpose:
flattened_t_xda = flattened_xda.transpose('z', 'band')
flattened_t_xda
<xarray.DataArray (z: 90000, band: 6)>
dask.array<transpose, shape=(90000, 6), dtype=float64, chunksize=(3000, 1), chunktype=numpy.ndarray>
Coordinates:
* band (band) int64 1 2 3 4 5 7
* z (z) object MultiIndex
* x (z) float64 3.324e+05 3.324e+05 3.324e+05 ... 3.772e+05 3.772e+05
* y (z) float64 4.309e+06 4.309e+06 4.309e+06 ... 4.264e+06 4.264e+06
Attributes:
transform: (150.0, 0.0, 332325.0, 0.0, -150.0, 4309275.0)
crs: +init=epsg:32611
res: (150.0, 150.0)
is_tiled: 0
nodatavals: (nan,)
scales: (1.0,)
offsets: (0.0,)
AREA_OR_POINT: Area
Standardize Data
Now that we have the data in the correct shape, let’s standardize (or rescale) the values of the data. We do this to get all the flattened image vectors onto a common scale while preserving the differences in the ranges of values. Again, we’ll demonstrate doing this first in NumPy and then xarray.
# Standardize: x_scaled = (x - mean) / std
rescaled_npa = (flattened_t_npa - flattened_t_npa.mean()) / flattened_t_npa.std()
rescaled_npa
array([[-1.29960701, -1.10062865, -0.87004784, -0.6195692 , 0.08036645,
-0.0799867 ],
[-1.0631739 , -0.76587681, -0.47443204, -0.10339592, 0.98981461,
0.54386898],
[-1.03742375, -0.65468302, -0.32695396, 0.06866184, 0.87159805,
0.37415215],
...,
[-0.51656863, -0.01678181, 0.52865299, 1.1759179 , 1.54110171,
1.2988163 ],
[-0.1338279 , 0.58483512, 1.2086908 , 1.9671495 , 2.40139051,
2.31009455],
[-0.64531934, -0.01678181, 0.28519712, 0.55323267, 1.07057641,
0.68198338]])
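As a sanity check, standardized values should have zero mean and unit standard deviation; a minimal sketch on toy data:

```python
import numpy as np

# Toy pixel matrix (n_samples, n_features)
data = np.array([[640., 810.],
                 [842., 1096.],
                 [864., 1191.]])

# Standardize: subtract the overall mean, divide by the overall standard deviation
rescaled = (data - data.mean()) / data.std()

assert np.isclose(rescaled.mean(), 0.0)
assert np.isclose(rescaled.std(), 1.0)
```

Note that this uses a single scalar mean and std across all bands, matching the cells above; standardizing per band would be a different design choice.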
with xr.set_options(keep_attrs=True):
rescaled_xda = (flattened_t_xda - flattened_t_xda.mean()) / flattened_t_xda.std()
rescaled_xda
<xarray.DataArray (z: 90000, band: 6)>
dask.array<truediv, shape=(90000, 6), dtype=float64, chunksize=(3000, 1), chunktype=numpy.ndarray>
Coordinates:
* band (band) int64 1 2 3 4 5 7
* z (z) object MultiIndex
* x (z) float64 3.324e+05 3.324e+05 3.324e+05 ... 3.772e+05 3.772e+05
* y (z) float64 4.309e+06 4.309e+06 4.309e+06 ... 4.264e+06 4.264e+06
Attributes:
transform: (150.0, 0.0, 332325.0, 0.0, -150.0, 4309275.0)
crs: +init=epsg:32611
res: (150.0, 150.0)
is_tiled: 0
nodatavals: (nan,)
scales: (1.0,)
offsets: (0.0,)
AREA_OR_POINT: Area
Info
Just a reminder: above, we are using the context manager “with xr.set_options(keep_attrs=True):” to retain the array’s attributes through the operations. That is, we want all the metadata, like ‘crs’, to stay with our result so we can use ‘geo=True’ in our plotting.
Since rescaled_xda is still a lazy Dask object, you can trigger the actual computation at this point (provided that all the data fits into memory) with .compute():
rescaled_xda.compute()
<xarray.DataArray (z: 90000, band: 6)>
array([[-1.29960701, -1.10062865, -0.87004784, -0.6195692 , 0.08036645,
-0.0799867 ],
[-1.1170151 , -0.76587681, -0.57392122, -0.21810109, 0.59536927,
0.14708272],
[-0.90750259, -0.54348923, -0.22863524, 0.29690172, 1.3046686 ,
0.77093841],
...,
[-1.05966251, -0.78694511, -0.59498952, -0.52593232, -0.41590899,
-0.7213993 ],
[-1.05966251, -0.78811557, -0.69330824, -0.52593232, -0.29769244,
-0.4966708 ],
[-0.64531934, -0.01678181, 0.28519712, 0.55323267, 1.07057641,
0.68198338]])
Coordinates:
* band (band) int64 1 2 3 4 5 7
* z (z) object MultiIndex
* x (z) float64 3.324e+05 3.324e+05 3.324e+05 ... 3.772e+05 3.772e+05
* y (z) float64 4.309e+06 4.309e+06 4.309e+06 ... 4.264e+06 4.264e+06
Attributes:
transform: (150.0, 0.0, 332325.0, 0.0, -150.0, 4309275.0)
crs: +init=epsg:32611
res: (150.0, 150.0)
is_tiled: 0
nodatavals: (nan,)
scales: (1.0,)
offsets: (0.0,)
AREA_OR_POINT: Area
ML pipeline
Now that our data is in the proper shape and value range, we are ready to conduct spectral clustering. Here we will use the version of spectral clustering from dask_ml, a scalable equivalent of the scikit-learn implementation, to cluster pixels based on their similarity across all bands (which makes it spectral clustering by spectra!).
The Machine Learning pipeline shown below is just for demonstration purposes, including the shaping/reshaping of data. In practice you will likely be using a more sophisticated pipeline.
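As an aside, for data that fits comfortably in memory, the analogous in-memory estimator from scikit-learn can be sketched as below. The alias avoids shadowing the dask_ml SpectralClustering imported earlier, and the random matrix is purely a stand-in for real pixel data:

```python
import numpy as np
from sklearn.cluster import SpectralClustering as SkSpectralClustering  # in-memory equivalent

# Toy (n_samples, n_features) matrix standing in for the rescaled pixel data
rng = np.random.RandomState(0)
X_small = rng.rand(100, 6)

# Same core parameters as the dask_ml version used below
clf_small = SkSpectralClustering(n_clusters=4, random_state=0)
labels_small = clf_small.fit_predict(X_small)

assert labels_small.shape == (100,)
assert set(labels_small) <= {0, 1, 2, 3}
```

The dask_ml version we use next adds approximation and chunked execution on top of this so it can scale to the full 90000-pixel matrix.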
client = Client(processes=False)
client
Client: Client-ac14fb54-9b80-11ed-8b4b-6045bdb91274
Connection method: Cluster object (distributed.LocalCluster)
Dashboard: http://10.1.0.232:8787/status
LocalCluster: 1 worker | 2 total threads | 6.78 GiB total memory | status: running | processes: False
Now we will compute and persist the rescaled data to feed into the ML pipeline. Notice that our X matrix below has the shape: n_samples, n_features as discussed earlier.
X = client.persist(rescaled_xda)
X.shape
(90000, 6)
First we will set up the model with the number of clusters, and other options.
clf = SpectralClustering(n_clusters=4, random_state=0, gamma=None,
kmeans_params={'init_max_iter': 5},
persist_embedding=True)
Now we’ll fit the model to our matrix X. This is the slow part: depending on your setup, it could take about 30 seconds to run on the small version of the data (on a relatively beefy laptop) or around 10 minutes for a full-size Landsat image.
%time clf.fit(X)
/usr/share/miniconda3/envs/cookbook-dev/lib/python3.10/site-packages/dask/base.py:1373: UserWarning: Running on a single-machine scheduler when a distributed client is active might lead to unexpected results.
warnings.warn(
CPU times: user 37 s, sys: 3.44 s, total: 40.4 s
Wall time: 47.2 s
SpectralClustering(gamma=None, kmeans_params={'init_max_iter': 5}, n_clusters=4,
                   persist_embedding=True, random_state=0)
labels = clf.assign_labels_.labels_.compute()
labels.shape
(90000,)
The result is a vector of cluster labels! OK, I know this doesn’t seem all that exciting yet, but we’re getting there. Next we will reshape the results into human-friendly image form.
labels
array([0, 0, 3, ..., 0, 0, 3], dtype=int32)
Un-flattening
Once the computation is done, the output can be used to create a new array with the same structure as the input array. This new output array needs the coordinates required to unstack it the same way the input was stacked. One of the main benefits of using xarray for this stacking and unstacking is that it keeps track of the coordinate information for us.
Since the original array is n_samples by n_features (90000, 6) and the output contains only a single label per sample (90000,), the template structure for this data needs to have the shape (n_samples,). We achieve this by just taking one of the bands.
template = flattened_t_xda[:, 0]
output_array = template.copy(data=labels)
output_array
<xarray.DataArray (z: 90000)>
array([0, 0, 3, ..., 0, 0, 3], dtype=int32)
Coordinates:
band int64 1
* z (z) object MultiIndex
* x (z) float64 3.324e+05 3.324e+05 3.324e+05 ... 3.772e+05 3.772e+05
* y (z) float64 4.309e+06 4.309e+06 4.309e+06 ... 4.264e+06 4.264e+06
Attributes:
transform: (150.0, 0.0, 332325.0, 0.0, -150.0, 4309275.0)
crs: +init=epsg:32611
res: (150.0, 150.0)
is_tiled: 0
nodatavals: (nan,)
scales: (1.0,)
offsets: (0.0,)
AREA_OR_POINT: Area
With this new output array in hand, we can unstack back to the original dimensions:
unstacked = output_array.unstack()
unstacked
<xarray.DataArray (x: 300, y: 300)>
array([[0, 0, 3, ..., 3, 3, 2],
[0, 0, 0, ..., 3, 3, 3],
[0, 0, 3, ..., 3, 3, 3],
...,
[3, 3, 3, ..., 3, 0, 2],
[0, 0, 0, ..., 0, 0, 2],
[0, 0, 0, ..., 0, 0, 3]], dtype=int32)
Coordinates:
* x (x) float64 3.324e+05 3.326e+05 3.327e+05 ... 3.771e+05 3.772e+05
* y (y) float64 4.309e+06 4.309e+06 4.309e+06 ... 4.264e+06 4.264e+06
band int64 1
Attributes:
transform: (150.0, 0.0, 332325.0, 0.0, -150.0, 4309275.0)
crs: +init=epsg:32611
res: (150.0, 150.0)
is_tiled: 0
nodatavals: (nan,)
scales: (1.0,)
offsets: (0.0,)
AREA_OR_POINT: Area
And finally, bring the results to life!
landsat_5_da.sel(band=4).hvplot.image(x='x', y='y', geo=True, datashade=True, cmap='greys', title='Raw Image') + \
unstacked.hvplot(x='x', y='y', cmap='Set3', geo=True, colorbar=False, title='Spectral Clustering Labels')
Spectral Clustering over time
Now that we have conducted the spectral clustering for one time, let’s bring it together with what we learned about regridding in the previous Preprocessing notebook to compare the results of this analysis from two different time points. The important conceptual goal here is to get the images from different acquisitions onto the same spatial grid so that we can have a chance to run computations that directly compare the images.
We already have Landsat 5 data (from 1988), so let’s just load Landsat 8 (from 2017).
landsat_8_da = cat.landsat_8_small.read_chunked()
See the previous preprocessing notebook for a detailed walkthrough on the following steps, but in summary, we are creating a bounding box and grid around our region of interest and then interpolating our data onto this new grid.
crs = ccrs.epsg(32611)
x_center, y_center = crs.transform_point(-118.7081, 38.6942, ccrs.PlateCarree())
buffer = 1.5e4
xmin = x_center - buffer
xmax = x_center + buffer
ymin = y_center - buffer
ymax = y_center + buffer
bounding_box = [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)]
res = 200
x = np.arange(xmin, xmax, res)
y = np.arange(ymin, ymax, res)
landsat_8_da_regridded = landsat_8_da.interp(x=x, y=y)
landsat_5_da_regridded = landsat_5_da.interp(x=x, y=y)
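By default, interp performs linear interpolation along each coordinate. A 1-D sketch of the same idea using plain NumPy (the values here are purely illustrative):

```python
import numpy as np

# 1-D analogue of the linear regridding that interp performs per coordinate
x_old = np.array([0.0, 1.0, 2.0, 3.0])          # original grid
values = np.array([10.0, 20.0, 30.0, 40.0])     # data on the original grid

x_new = np.array([0.5, 1.5, 2.5])               # target grid
regridded = np.interp(x_new, x_old, values)

assert np.allclose(regridded, [15.0, 25.0, 35.0])
```

Each target point gets a value linearly blended from its two nearest original grid points; interp does this along both x and y for our images.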
Let’s take a look at our regridded data. Notice that hvPlot understands that the two arrays share the band dimension and automatically links them to the same widget.
landsat_8_da_regridded.hvplot.image(x='x', y='y', geo=True, title='Landsat 8 2017', colorbar=False, rasterize=True, cmap='viridis') +\
landsat_5_da_regridded.hvplot.image(x='x', y='y', geo=True, title='Landsat 5 1988', colorbar=False, rasterize=True, cmap='viridis')
Now let’s run the same spectral clustering steps that we saw earlier, but on this new regridded data. Again, we will start with reshaping and rescaling the data.
l5_rg_flat_xda = landsat_5_da_regridded.stack(z=('x','y')).transpose('z', 'band')
l8_rg_flat_xda = landsat_8_da_regridded.stack(z=('x','y')).transpose('z', 'band')
l5_rg_rescale_xda = (l5_rg_flat_xda - l5_rg_flat_xda.mean()) / l5_rg_flat_xda.std()
l8_rg_rescale_xda = (l8_rg_flat_xda - l8_rg_flat_xda.mean()) / l8_rg_flat_xda.std()
l5_X = client.persist(l5_rg_rescale_xda)
l8_X = client.persist(l8_rg_rescale_xda)
And now we fit the data to our model.
l5_clf = SpectralClustering(n_clusters=4, random_state=0, gamma=None,
kmeans_params={'init_max_iter': 5},
persist_embedding=True)
%time l5_clf.fit(l5_X)
/usr/share/miniconda3/envs/cookbook-dev/lib/python3.10/site-packages/dask/base.py:1373: UserWarning: Running on a single-machine scheduler when a distributed client is active might lead to unexpected results.
warnings.warn(
CPU times: user 34.4 s, sys: 2.28 s, total: 36.7 s
Wall time: 44.9 s
SpectralClustering(gamma=None, kmeans_params={'init_max_iter': 5}, n_clusters=4,
                   persist_embedding=True, random_state=0)
l8_clf = SpectralClustering(n_clusters=4, random_state=0, gamma=None,
kmeans_params={'init_max_iter': 5},
persist_embedding=True)
%time l8_clf.fit(l8_X)
/usr/share/miniconda3/envs/cookbook-dev/lib/python3.10/site-packages/dask/base.py:1373: UserWarning: Running on a single-machine scheduler when a distributed client is active might lead to unexpected results.
warnings.warn(
CPU times: user 28 s, sys: 1.87 s, total: 29.9 s
Wall time: 37.1 s
SpectralClustering(gamma=None, kmeans_params={'init_max_iter': 5}, n_clusters=4,
                   persist_embedding=True, random_state=0)
l5_labels = l5_clf.assign_labels_.labels_.compute()
l8_labels = l8_clf.assign_labels_.labels_.compute()
And the last step before the big reveal is to reshape the results back into image form:
l5_template = l5_rg_flat_xda[:, 0]
l5_output_array = l5_template.copy(data=l5_labels)
l8_template = l8_rg_flat_xda[:, 0]
l8_output_array = l8_template.copy(data=l8_labels)
l5_labels_unstacked = l5_output_array.unstack()
l8_labels_unstacked = l8_output_array.unstack()
Ta-da!
l5_labels_unstacked.hvplot(x='x', y='y', width=400, height=400, cmap='Set3', geo=True, colorbar=False, title='1988 Labels') +\
l8_labels_unstacked.hvplot(x='x', y='y', width=400, height=400, cmap='Set3', geo=True, colorbar=False, title='2017 Labels')
But wait, the spectral clustering labels of water are clearly different between the two years. If we want to directly compare the amount of water across these images, we’ll have to create a mask using the appropriate label from each image that is indicative of water. Since we are using interactive plotting, we can just hover over the lake in these images to discover that we are interested in cluster label 1 (blue) for the 1988 data and cluster label 3 (yellow) for the 2017 data. Great, now let’s create those water masks.
l5_labels_mask = l5_labels_unstacked.where(l5_labels_unstacked == 1, 0) # set non-1 to 0
l8_labels_mask = l8_labels_unstacked.where(l8_labels_unstacked == 3, 0) # set non-3 to 0
l8_labels_mask = l8_labels_mask.where(l8_labels_mask != 3, 1) # set 3 -> 1
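The where(cond, other) pattern used above keeps values where the condition holds and fills in other everywhere else; a toy sketch:

```python
import numpy as np
import xarray as xr

# Toy label image with cluster ids 0, 1, and 3
labels = xr.DataArray(np.array([[0, 1],
                                [3, 1]]), dims=('y', 'x'))

# where(cond, other): keep values where cond is True, fill with `other` elsewhere
mask = labels.where(labels == 1, 0)

assert (mask.values == np.array([[0, 1],
                                 [0, 1]])).all()
```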
l5_labels_mask.hvplot(x='x', y='y', cmap='greys', geo=True, colorbar=False, title='1988 Water Mask') +\
l8_labels_mask.hvplot(x='x', y='y', cmap='greys', geo=True, colorbar=False, title='2017 Water Mask')
Now we can take the difference of these water label masks to see exactly where the water level has changed.
with xr.set_options(keep_attrs=True):
l8_l5_specdiff = l8_labels_mask - l5_labels_mask
l8_l5_specdiff.hvplot(x='x', y='y', width=400, height=400, cmap='blues', geo=True, alpha=.7, colorbar=False, title='2017-1988 Labels', tiles='ESRI')
We did it! Above, the white pixels around the lake are regions where there was water in 1988 but not in 2017.
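Since both masks are 0/1 arrays, the difference can only take the values -1, 0, and +1; a toy sketch of how to read them:

```python
import numpy as np

# Toy 0/1 water masks for two years
mask_1988 = np.array([[1, 1],
                      [0, 0]])
mask_2017 = np.array([[1, 0],
                      [1, 0]])

diff = mask_2017 - mask_1988
# -1: water in 1988 but not 2017 (water lost); +1: water gained; 0: no change
assert (diff == np.array([[0, -1],
                          [1, 0]])).all()
```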
Summary
In this notebook we covered reshaping and rescaling the data to get it into a format ready for machine learning. Then we conducted spectral clustering to get label images of spots where there was likely water, and finally used our regridding approach to compare the water regions from different time points.